package org.tensorflow.framework.op;

import static org.junit.jupiter.api.Assertions.assertThrows;

import org.junit.jupiter.api.Test;
import org.tensorflow.Operand;
import org.tensorflow.framework.utils.TestSession;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;
import org.tensorflow.types.TInt64;

class MathOpsTest {

  private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};

  double[][][] array =
      new double[][][] {
        {
          {4.17021990e-01, 7.20324516e-01, 1.14374816e-04},
          {3.02332580e-01, 1.46755889e-01, 9.23385918e-02},
          {1.86260208e-01, 3.45560730e-01, 3.96767467e-01},
          {5.38816750e-01, 4.19194520e-01, 6.85219526e-01},
          {2.04452246e-01, 8.78117442e-01, 2.73875929e-02},
          {6.70467496e-01, 4.17304814e-01, 5.58689833e-01},
          {1.40386939e-01, 1.98101491e-01, 8.00744593e-01}
        },
        {
          {9.68261600e-01, 3.13424170e-01, 6.92322612e-01},
          {8.76389146e-01, 8.94606650e-01, 8.50442126e-02},
          {3.90547849e-02, 1.69830427e-01, 8.78142476e-01},
          {9.83468369e-02, 4.21107620e-01, 9.57889557e-01},
          {5.33165276e-01, 6.91877127e-01, 3.15515637e-01},
          {6.86500907e-01, 8.34625661e-01, 1.82882771e-02},
          {7.50144303e-01, 9.88861084e-01, 7.48165667e-01}
        },
        {
          {2.80443996e-01, 7.89279342e-01, 1.03226006e-01},
          {4.47893530e-01, 9.08595502e-01, 2.93614149e-01},
          {2.87775338e-01, 1.30028576e-01, 1.93669572e-02},
          {6.78835511e-01, 2.11628109e-01, 2.65546650e-01},
          {4.91573155e-01, 5.33625446e-02, 5.74117601e-01},
          {1.46728575e-01, 5.89305520e-01, 6.99758351e-01},
          {1.02334432e-01, 4.14055973e-01, 6.94400132e-01}
        },
        {
          {4.14179265e-01, 4.99534607e-02, 5.35896420e-01},
          {6.63794637e-01, 5.14889121e-01, 9.44594741e-01},
          {5.86555064e-01, 9.03401911e-01, 1.37474701e-01},
          {1.39276341e-01, 8.07391286e-01, 3.97676826e-01},
          {1.65354192e-01, 9.27508593e-01, 3.47765863e-01},
          {7.50812113e-01, 7.25997984e-01, 8.83306086e-01},
          {6.23672187e-01, 7.50942409e-01, 3.48898351e-01}
        },
        {
          {2.69927889e-01, 8.95886242e-01, 4.28091198e-01},
          {9.64840055e-01, 6.63441479e-01, 6.21695697e-01},
          {1.14745975e-01, 9.49489236e-01, 4.49912131e-01},
          {5.78389585e-01, 4.08136815e-01, 2.37026975e-01},
          {9.03379500e-01, 5.73679507e-01, 2.87032709e-03},
          {6.17144942e-01, 3.26644897e-01, 5.27058125e-01},
          {8.85942101e-01, 3.57269764e-01, 9.08535123e-01}
        },
        {
          {6.23360097e-01, 1.58212427e-02, 9.29437220e-01},
          {6.90896928e-01, 9.97322857e-01, 1.72340512e-01},
          {1.37135744e-01, 9.32595491e-01, 6.96818173e-01},
          {6.60001710e-02, 7.55463064e-01, 7.53876209e-01},
          {9.23024535e-01, 7.11524785e-01, 1.24270961e-01},
          {1.98801346e-02, 2.62109861e-02, 2.83064879e-02},
          {2.46211067e-01, 8.60027969e-01, 5.38831055e-01}
        },
        {
          {5.52821994e-01, 8.42030883e-01, 1.24173313e-01},
          {2.79183686e-01, 5.85759282e-01, 9.69595730e-01},
          {5.61030209e-01, 1.86472889e-02, 8.00632656e-01},
          {2.32974276e-01, 8.07105184e-01, 3.87860656e-01},
          {8.63541842e-01, 7.47121632e-01, 5.56240261e-01},
          {1.36455223e-01, 5.99176884e-02, 1.21343456e-01},
          {4.45518792e-02, 1.07494131e-01, 2.25709334e-01}
        },
        {
          {7.12988973e-01, 5.59717000e-01, 1.25559801e-02},
          {7.19742775e-02, 9.67276335e-01, 5.68100452e-01},
          {2.03293234e-01, 2.52325743e-01, 7.43825853e-01},
          {1.95429474e-01, 5.81358910e-01, 9.70019996e-01},
          {8.46828818e-01, 2.39847764e-01, 4.93769705e-01},
          {6.19955719e-01, 8.28980923e-01, 1.56791389e-01},
          {1.85762029e-02, 7.00221434e-02, 4.86345112e-01}
        },
        {
          {6.06329441e-01, 5.68851411e-01, 3.17362398e-01},
          {9.88616168e-01, 5.79745233e-01, 3.80141169e-01},
          {5.50948203e-01, 7.45334446e-01, 6.69232905e-01},
          {2.64919549e-01, 6.63348362e-02, 3.70084196e-01},
          {6.29717529e-01, 2.10174009e-01, 7.52755582e-01},
          {6.65364787e-02, 2.60315090e-01, 8.04754555e-01},
          {1.93434283e-01, 6.39460862e-01, 5.24670303e-01}
        },
        {
          {9.24807966e-01, 2.63296783e-01, 6.59610927e-02},
          {7.35065937e-01, 7.72178054e-01, 9.07815874e-01},
          {9.31972086e-01, 1.39515726e-02, 2.34362081e-01},
          {6.16778374e-01, 9.49016333e-01, 9.50176120e-01},
          {5.56653202e-01, 9.15606380e-01, 6.41566217e-01},
          {3.90007704e-01, 4.85990673e-01, 6.04310513e-01},
          {5.49547911e-01, 9.26181436e-01, 9.18733418e-01}
        },
        {
          {3.94875616e-01, 9.63262558e-01, 1.73955664e-01},
          {1.26329526e-01, 1.35079160e-01, 5.05662143e-01},
          {2.15248056e-02, 9.47970212e-01, 8.27115476e-01},
          {1.50189810e-02, 1.76196262e-01, 3.32063586e-01},
          {1.30996838e-01, 8.09490681e-01, 3.44736665e-01},
          {9.40107465e-01, 5.82014203e-01, 8.78831983e-01},
          {8.44734430e-01, 9.05392289e-01, 4.59880263e-01}
        },
        {
          {5.46346843e-01, 7.98603594e-01, 2.85718858e-01},
          {4.90253508e-01, 5.99110305e-01, 1.55332759e-02},
          {5.93481421e-01, 4.33676362e-01, 8.07360530e-01},
          {3.15244794e-01, 8.92888725e-01, 5.77857196e-01},
          {1.84010208e-01, 7.87929237e-01, 6.12031162e-01},
          {5.39092720e-02, 4.20193672e-01, 6.79068863e-01},
          {9.18601751e-01, 4.02024889e-04, 9.76759136e-01}
        },
        {
          {3.76580328e-01, 9.73783553e-01, 6.04716122e-01},
          {8.28845799e-01, 5.74711502e-01, 6.28076196e-01},
          {2.85576284e-01, 5.86833358e-01, 7.50021756e-01},
          {8.58313859e-01, 7.55082190e-01, 6.98057234e-01},
          {8.64479423e-01, 3.22681010e-01, 6.70788765e-01},
          {4.50873941e-01, 3.82102758e-01, 4.10811365e-01},
          {4.01479572e-01, 3.17383945e-01, 6.21919394e-01}
        },
        {
          {4.30247277e-01, 9.73802090e-01, 6.77800894e-01},
          {1.98569894e-01, 4.26701009e-01, 3.43346238e-01},
          {7.97638834e-01, 8.79998267e-01, 9.03841972e-01},
          {6.62719786e-01, 2.70208269e-01, 2.52366692e-01},
          {8.54897916e-01, 5.27714670e-01, 8.02161098e-01},
          {5.72488546e-01, 7.33142555e-01, 5.19011617e-01},
          {7.70883918e-01, 5.68857968e-01, 4.65709865e-01}
        },
        {
          {3.42688918e-01, 6.82093501e-02, 3.77924174e-01},
          {7.96260759e-02, 9.82817113e-01, 1.81612849e-01},
          {8.11858714e-01, 8.74961674e-01, 6.88413262e-01},
          {5.69494426e-01, 1.60971433e-01, 4.66880023e-01},
          {3.45172048e-01, 2.25039959e-01, 5.92511892e-01},
          {3.12269837e-01, 9.16305542e-01, 9.09635544e-01},
          {2.57118285e-01, 1.10891297e-01, 1.92962736e-01}
        },
        {
          {4.99584168e-01, 7.28585660e-01, 2.08194435e-01},
          {2.48033553e-01, 8.51671875e-01, 4.15848732e-01},
          {6.16685092e-01, 2.33666137e-01, 1.01967260e-01},
          {5.15857041e-01, 4.77140993e-01, 1.52671650e-01},
          {6.21806204e-01, 5.44010103e-01, 6.54137373e-01},
          {1.44545540e-01, 7.51527846e-01, 2.22049147e-01},
          {5.19351840e-01, 7.85296023e-01, 2.23304275e-02}
        },
        {
          {3.24362457e-01, 8.72922361e-01, 8.44709635e-01},
          {5.38440585e-01, 8.66608262e-01, 9.49805975e-01},
          {8.26407015e-01, 8.54115427e-01, 9.87434015e-02},
          {6.51304305e-01, 7.03516960e-01, 6.10240817e-01},
          {7.99615264e-01, 3.45712192e-02, 7.70238757e-01},
          {7.31728613e-01, 2.59698391e-01, 2.57069290e-01},
          {6.32303298e-01, 3.45297456e-01, 7.96588659e-01}
        },
        {
          {4.46146220e-01, 7.82749414e-01, 9.90471780e-01},
          {3.00248325e-01, 1.43005833e-01, 9.01308417e-01},
          {5.41559398e-01, 9.74740386e-01, 6.36604428e-01},
          {9.93912995e-01, 5.46070814e-01, 5.26425958e-01},
          {1.35427907e-01, 3.55705172e-01, 2.62185670e-02},
          {1.60395175e-01, 7.45637178e-01, 3.03996895e-02},
          {3.66543084e-01, 8.62346232e-01, 6.92677736e-01}
        },
        {
          {6.90942168e-01, 1.88636795e-01, 4.41904277e-01},
          {5.81577420e-01, 9.89751697e-01, 2.03906223e-01},
          {2.47732908e-01, 2.62173086e-01, 7.50172436e-01},
          {4.56975341e-01, 5.69294393e-02, 5.08516252e-01},
          {2.11960167e-01, 7.98604250e-01, 2.97331393e-01},
          {2.76060123e-02, 5.93432426e-01, 8.43840420e-01},
          {3.81016135e-01, 7.49858320e-01, 5.11141479e-01}
        },
        {
          {5.40951788e-01, 9.59434330e-01, 8.03960919e-01},
          {3.23230661e-02, 7.09387243e-01, 4.65001494e-01},
          {9.47548926e-01, 2.21432731e-01, 2.67072022e-01},
          {8.14739615e-02, 4.28618819e-01, 1.09018765e-01},
          {6.33786738e-01, 8.02963257e-01, 6.96800470e-01},
          {7.66211390e-01, 3.42454106e-01, 8.45851481e-01},
          {4.28768784e-01, 8.24009895e-01, 6.26496136e-01}
        }
      };

  double[][][] expectedArray = {
    {
      {3.45350616e-02, 5.96526116e-02, 9.47178160e-06},
      {2.50372272e-02, 1.21533722e-02, 7.64688430e-03},
      {1.54248644e-02, 2.86171008e-02, 3.28577124e-02},
      {4.46213149e-02, 3.47149745e-02, 5.67454435e-02},
      {1.69314109e-02, 7.27199987e-02, 2.26806314e-03},
      {5.55237755e-02, 3.45584825e-02, 4.62670736e-02},
      {1.16259372e-02, 1.64054818e-02, 6.63124844e-02}
    },
    {
      {8.01851526e-02, 2.59557609e-02, 5.73336743e-02},
      {7.25768730e-02, 7.40855262e-02, 7.04281079e-03},
      {3.23426444e-03, 1.40642561e-02, 7.27220699e-02},
      {8.14444851e-03, 3.48734073e-02, 7.93262124e-02},
      {4.41532955e-02, 5.72967827e-02, 2.61289626e-02},
      {5.68515584e-02, 6.91182911e-02, 1.51451665e-03},
      {6.21220917e-02, 8.18910673e-02, 6.19582348e-02}
    },
    {
      {2.32245550e-02, 6.53630048e-02, 8.54850933e-03},
      {3.70916426e-02, 7.52439946e-02, 2.43152231e-02},
      {2.38316897e-02, 1.07681248e-02, 1.60384597e-03},
      {5.62167615e-02, 1.75256692e-02, 2.19908543e-02},
      {4.07089069e-02, 4.41914052e-03, 4.75447029e-02},
      {1.21511100e-02, 4.88024652e-02, 5.79494536e-02},
      {8.47467501e-03, 3.42894346e-02, 5.75057231e-02}
    },
    {
      {3.42996456e-02, 4.13682219e-03, 4.43794727e-02},
      {5.49711734e-02, 4.26397808e-02, 7.82252178e-02},
      {4.85746935e-02, 7.48138949e-02, 1.13847647e-02},
      {1.15339644e-02, 6.68629184e-02, 3.29330191e-02},
      {1.36935636e-02, 7.68102556e-02, 2.87997164e-02},
      {6.21773973e-02, 6.01224527e-02, 7.31496885e-02},
      {5.16484901e-02, 6.21881858e-02, 2.88935024e-02}
    },
    {
      {2.23536789e-02, 7.41914958e-02, 3.54517400e-02},
      {7.99018070e-02, 5.49419262e-02, 5.14848121e-02},
      {9.50251892e-03, 7.86305517e-02, 3.72588076e-02},
      {4.78984788e-02, 3.37992460e-02, 1.96290389e-02},
      {7.48120397e-02, 4.75084223e-02, 2.37701897e-04},
      {5.11079468e-02, 2.70506144e-02, 4.36475389e-02},
      {7.33679906e-02, 2.95867678e-02, 7.52389953e-02}
    },
    {
      {5.16226478e-02, 1.31021289e-03, 7.69699737e-02},
      {5.72156087e-02, 8.25918168e-02, 1.42721254e-02},
      {1.13566946e-02, 7.72315189e-02, 5.77059686e-02},
      {5.46570681e-03, 6.25625551e-02, 6.24311455e-02},
      {7.64389113e-02, 5.89238741e-02, 1.02913165e-02},
      {1.64634397e-03, 2.17062421e-03, 2.34416011e-03},
      {2.03896053e-02, 7.12219477e-02, 4.46224995e-02}
    },
    {
      {4.57811356e-02, 6.97315410e-02, 1.02832299e-02},
      {2.31201854e-02, 4.85087894e-02, 8.02956372e-02},
      {4.64608893e-02, 1.54424773e-03, 6.63032085e-02},
      {1.92934200e-02, 6.68392256e-02, 3.21201086e-02},
      {7.15129450e-02, 6.18717745e-02, 4.60642166e-02},
      {1.13003375e-02, 4.96199494e-03, 1.00488793e-02},
      {3.68949817e-03, 8.90196767e-03, 1.86917856e-02}
    },
    {
      {5.90451285e-02, 4.63521369e-02, 1.03980501e-03},
      {5.96044352e-03, 8.01035613e-02, 4.70464006e-02},
      {1.68354288e-02, 2.08959840e-02, 6.15988411e-02},
      {1.61842033e-02, 4.81443815e-02, 8.03307742e-02},
      {7.01288804e-02, 1.98626388e-02, 4.08908091e-02},
      {5.13407178e-02, 6.86508343e-02, 1.29844472e-02},
      {1.53836084e-03, 5.79878036e-03, 4.02759537e-02}
    },
    {
      {5.02122790e-02, 4.71085906e-02, 2.62818988e-02},
      {8.18707868e-02, 4.80107442e-02, 3.14808302e-02},
      {4.56259623e-02, 6.17237724e-02, 5.54215349e-02},
      {2.19389219e-02, 5.49342157e-03, 3.06479763e-02},
      {5.21491282e-02, 1.74052510e-02, 6.23383410e-02},
      {5.51012019e-03, 2.15576105e-02, 6.66445568e-02},
      {1.60189737e-02, 5.29560074e-02, 4.34497967e-02}
    },
    {
      {7.65866041e-02, 2.18045339e-02, 5.46247046e-03},
      {6.08734004e-02, 6.39467835e-02, 7.51794279e-02},
      {7.71798939e-02, 1.15537888e-03, 1.94083489e-02},
      {5.10775894e-02, 7.85913840e-02, 7.86874294e-02},
      {4.60984148e-02, 7.58245885e-02, 5.31303585e-02},
      {3.22979130e-02, 4.02465984e-02, 5.00450842e-02},
      {4.55099978e-02, 7.67003447e-02, 7.60835484e-02}
    },
    {
      {3.27010415e-02, 7.97711685e-02, 1.44058811e-02},
      {1.04617933e-02, 1.11863809e-02, 4.18756641e-02},
      {1.78254500e-03, 7.85047561e-02, 6.84963465e-02},
      {1.24377478e-03, 1.45914331e-02, 2.74993554e-02},
      {1.08483098e-02, 6.70367777e-02, 2.85488572e-02},
      {7.78536126e-02, 4.81986478e-02, 7.27791712e-02},
      {6.99554384e-02, 7.49787241e-02, 3.80843058e-02}
    },
    {
      {4.52449061e-02, 6.61351755e-02, 2.36613862e-02},
      {4.05996218e-02, 4.96144369e-02, 1.28636532e-03},
      {4.91482876e-02, 3.59142683e-02, 6.68603703e-02},
      {2.61065327e-02, 7.39432648e-02, 4.78543900e-02},
      {1.52385337e-02, 6.52511939e-02, 5.06844558e-02},
      {4.46441676e-03, 3.47977169e-02, 5.62360846e-02},
      {7.60726482e-02, 3.32930977e-05, 8.08888674e-02}
    },
    {
      {3.11859436e-02, 8.06424469e-02, 5.00786714e-02},
      {6.86396435e-02, 4.75938842e-02, 5.20132035e-02},
      {2.36495789e-02, 4.85977381e-02, 6.21119440e-02},
      {7.10799918e-02, 6.25310168e-02, 5.78085780e-02},
      {7.15905875e-02, 2.67223511e-02, 5.55503815e-02},
      {3.73384580e-02, 3.16432752e-02, 3.40207368e-02},
      {3.32479365e-02, 2.62836833e-02, 5.15033379e-02}
    },
    {
      {3.56302932e-02, 8.06439817e-02, 5.61310798e-02},
      {1.64442733e-02, 3.53366137e-02, 2.84337122e-02},
      {6.60552830e-02, 7.28757605e-02, 7.48503357e-02},
      {5.48821613e-02, 2.23768987e-02, 2.08993759e-02},
      {7.07971081e-02, 4.37019095e-02, 6.64297864e-02},
      {4.74097952e-02, 6.07141182e-02, 4.29811813e-02},
      {6.38396144e-02, 4.71091345e-02, 3.85670736e-02}
    },
    {
      {2.83792764e-02, 5.64865675e-03, 3.12972330e-02},
      {6.59411587e-03, 8.13905448e-02, 1.50400000e-02},
      {6.72328845e-02, 7.24586621e-02, 5.70099279e-02},
      {4.71618399e-02, 1.33306114e-02, 3.86639796e-02},
      {2.85849143e-02, 1.86363515e-02, 4.90679964e-02},
      {2.58601662e-02, 7.58824944e-02, 7.53301233e-02},
      {2.12928709e-02, 9.18329880e-03, 1.59799233e-02}
    },
    {
      {4.13723253e-02, 6.03367463e-02, 1.72413141e-02},
      {2.05405317e-02, 7.05299526e-02, 3.44378985e-02},
      {5.10698669e-02, 1.93507168e-02, 8.44426826e-03},
      {4.27199379e-02, 3.95137258e-02, 1.26432776e-02},
      {5.14939614e-02, 4.50513922e-02, 5.41714206e-02},
      {1.19703254e-02, 6.22366704e-02, 1.83886718e-02},
      {4.30093557e-02, 6.50331303e-02, 1.84926135e-03}
    },
    {
      {2.68615987e-02, 7.22897798e-02, 6.99533820e-02},
      {4.45901640e-02, 7.17668831e-02, 7.86567777e-02},
      {6.84376806e-02, 7.07323104e-02, 8.17728881e-03},
      {5.39368056e-02, 5.82607202e-02, 5.05361930e-02},
      {6.62189573e-02, 2.86296452e-03, 6.37861863e-02},
      {6.05970249e-02, 2.15065386e-02, 2.12888140e-02},
      {5.23632653e-02, 2.85952985e-02, 6.59683123e-02}
    },
    {
      {3.69469412e-02, 6.48222342e-02, 8.20244551e-02},
      {2.48646215e-02, 1.18428171e-02, 7.46405274e-02},
      {4.48484421e-02, 8.07216838e-02, 5.27194552e-02},
      {8.23094398e-02, 4.52220477e-02, 4.35951874e-02},
      {1.12152621e-02, 2.94571985e-02, 2.17125192e-03},
      {1.32828895e-02, 6.17488436e-02, 2.51750532e-03},
      {3.03547252e-02, 7.14139268e-02, 5.73630854e-02}
    },
    {
      {5.72193563e-02, 1.56216780e-02, 3.65956500e-02},
      {4.81624752e-02, 8.19648281e-02, 1.68861933e-02},
      {2.05156356e-02, 2.17114780e-02, 6.21244237e-02},
      {3.78437378e-02, 4.71452763e-03, 4.21120226e-02},
      {1.75531674e-02, 6.61352351e-02, 2.46230606e-02},
      {2.28615105e-03, 4.91442308e-02, 6.98814020e-02},
      {3.15532871e-02, 6.20984100e-02, 4.23294269e-02}
    },
    {
      {4.47981246e-02, 7.94541389e-02, 6.65788352e-02},
      {2.67678709e-03, 5.87468557e-02, 3.85084115e-02},
      {7.84698650e-02, 1.83376241e-02, 2.21171752e-02},
      {6.74714567e-03, 3.54954340e-02, 9.02822800e-03},
      {5.24861142e-02, 6.64962158e-02, 5.77045009e-02},
      {6.34526685e-02, 2.83598304e-02, 7.00479448e-02},
      {3.55078541e-02, 6.82391599e-02, 5.18823527e-02}
    }
  };

  @Test
  public void testL2Normalize() {
    for (TestSession.Mode tfMode : tfModes)
      try (TestSession session = TestSession.createTestSession(tfMode)) {
        Ops tf = session.getTF();
        FrameworkOps fops = FrameworkOps.create(tf);
        Operand<TFloat64> input = tf.constant(array);
        Operand<TFloat64> result = fops.math.l2Normalize(tf.constant(array), new int[] {0, 1, 2});
        session.evaluate(tf.constant(expectedArray), result);
      }
  }

  @Test
  public void testConfusionMatrix() {
    for (TestSession.Mode tfMode : tfModes)
      try (TestSession session = TestSession.createTestSession(tfMode)) {
        Ops tf = session.getTF();
        FrameworkOps fops = FrameworkOps.create(tf);
        long[] labels = new long[] {2, 0, 2, 2, 0, 1};
        long[] predictions = new long[] {0, 0, 2, 2, 0, 2};
        Operand<TInt64> result =
            fops.math.confusionMatrix(tf.constant(labels), tf.constant(predictions));
        long[][] expected =
            new long[][] {
              {2, 0, 0},
              {0, 0, 1},
              {1, 0, 2}
            };
        session.evaluate(tf.constant(expected), result);
      }
  }

  @Test
  public void testTensorDotValid() {
    for (TestSession.Mode tfMode : tfModes)
      try (TestSession session = TestSession.createTestSession(tfMode)) {
        Ops tf = session.getTF();
        FrameworkOps fops = FrameworkOps.create(tf);
        int[] axes1 = new int[] {1, 2};
        int[][] axes2 = new int[][] {{1}, {2}};
        int[][] axes3 = new int[2][0];
        int axes4 = 0;

        Operand<TFloat32> a = tf.ones(tf.constant(Shape.of(3, 3)), TFloat32.class);
        Operand<TFloat32> b = tf.constant(new float[][][] {{{2, 3, 1}}});

        Operand<TFloat32> ans = fops.math.tensordot(a, b, axes1);
        Operand<TFloat32> expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}});
        session.evaluate(expected, ans);

        ans = fops.math.tensordot(a, b, axes2);
        expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}});
        session.evaluate(expected, ans);

        float[][][][][] expectedArray =
            new float[][][][][] {
              {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}},
              {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}},
              {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}
            };
        ans = fops.math.tensordot(a, b, axes3);
        expected = tf.constant(expectedArray);
        session.evaluate(expected, ans);

        ans = fops.math.tensordot(a, b, axes4);
        expected = tf.constant(expectedArray);
        session.evaluate(expected, ans);
      }
  }

  @Test
  public void testTensorDotInValidAxis() {
    for (TestSession.Mode tfMode : tfModes)
      try (TestSession session = TestSession.createTestSession(tfMode)) {
        Ops tf = session.getTF();
        FrameworkOps fops = FrameworkOps.create(tf);
        Operand<TFloat32> a = tf.constant(new float[][] {{1, 2}, {3, 4}});
        Operand<TFloat32> b = tf.constant(new float[][] {{1, 2}, {3, 4}});
        assertThrows(IllegalArgumentException.class, () -> fops.math.tensordot(a, b, -1));
        assertThrows(IllegalArgumentException.class, () -> fops.math.tensordot(a, b, 3));
        assertThrows(
            IllegalArgumentException.class, () -> fops.math.tensordot(a, b, new int[] {1}));
        assertThrows(
            IllegalArgumentException.class, () -> fops.math.tensordot(a, b, new int[][] {{1}}));
        assertThrows(
            IllegalArgumentException.class,
            () -> fops.math.tensordot(a, b, new int[][] {{1}, {0, 1}}));

        assertThrows(
            ArrayIndexOutOfBoundsException.class,
            () -> fops.math.tensordot(a, b, new int[][] {{0}, {7}}));
      }
  }

  @Test
  public void testReduceLogSumExp() {
    for (TestSession.Mode tfMode : tfModes)
      try (TestSession session = TestSession.createTestSession(tfMode)) {
        Ops tf = session.getTF();
        FrameworkOps fops = FrameworkOps.create(tf);
        Operand<TFloat32> x =
            tf.constant(
                new float[][] {
                  {0.43346116f, 0.8569728f, 0.57155997f, 0.0743812f, 0.63846475f},
                  {0.8165283f, 0.26554802f, 0.37025765f, 0.8255019f, 0.45682374f},
                  {0.93511814f, 0.52291054f, 0.80983895f, 0.11580781f, 0.8111686f},
                  {0.49967498f, 0.27537802f, 0.48554695f, 0.28238368f, 0.7989301f},
                  {0.8958915f, 0.84870094f, 0.56874424f, 0.08818512f, 0.13915819f}
                });

        Operand<TFloat32> result = fops.math.reduceLogSumExp(x, new int[] {0, 1}, false);
        session.evaluate(3.7911222f, result);
      }
  }
}
